iT邦幫忙

2022 iThome 鐵人賽

0
AI & Data

JAX 好好玩系列 第 38

JAX 好好玩 (38) : Flax (4) : 自訂模型

  • 分享至 

  • xImage
  •  

這篇貼文,乃是針對「第二個範例程式」中的「自訂模型」的部份,加以詳細的說明。

Flax 提供兩種方式來定義「使用者自訂模型」,一為「明確的 explict」宣告法,一為「精簡的 compact 」(或可稱為『行內的 in-line』) 宣告法。不管使用那一種方式,所有自訂的模型,都必須繼承 flax.linen.Module。習慣上大家都用 import flax.linen as nn 來載入 linen 封裝 (package), 使得大多數的範例程式中,我們看到的都是 nn.Module 這種寫法。

範例程式中,自訂模型要解決的是「cifar10 分類問題」。cifar10 資料集包括了 10 種類型的 (32 x 32 x 3) RGB 圖片,模型的輸入及輸出,將符合它的圖片規格。cifar10 資料集的重要參數定義如下:

# cifar10 資料集內共有 10 種圖形, 如下:
DS_Labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
DS_ClassNumber = len(DS_Labels)
 
# cifar10 的圖片尺寸
DS_ImageShape = (32,32,3)
DS_ImageFlatened = 32*32*3

因此所設計的模型,其輸人維度應該是 (批次, 32, 32, 3),其輸出應該是 (批次, 10)。

明確的模型宣告法

https://ithelp.ithome.com.tw/upload/images/20221101/2012961617jhe6vJ4L.png

範例中的第一個模型,是以明確方式定義的 MLP 模型。MLP 模型是以數個「密集連結層 (dense layer)」,在這裏使用了一個技巧,以 keyward parameter (即程式中的 features ) 在宣告類別案例 (instance) 的時候,指定模型的超參數。

model_expmlp = ExplicitMLP(features=[512,256,128,64,32,10])

這個例子指定了 6 層 dense layers,以及每層對應的神經元個數 (在 Flax 中,稱其為 feature 個數)。

接著,必須在 setup(self) 類別函式中,宣告這個模型所包含的子層 (或著子模型)。我們使用 6 個密集連結層為此模型的子層,Flax 提供了預先定義好的 nn.Dense(),我們直接使用它就可以了。

self.layers = [nn.Dense(feat) for feat in self.features]

如果不想寫得那麼有技巧,可以一層一層個別的宣告。

self.layer1 = nn.Dense(512)
self.layer2 = nn.Dense(256)
self.layer3 = nn.Dense(128)
self.layer4 = nn.Dense(64)
self.layer5 = nn.Dense(32)
self.layer6 = nn.Dense(10)

大家可以查閱 Flax 官方文件 [38.1],看看 Flax 已經預先定義了那些子層可供使用。

在 setup() 之後,必須宣告 __call__(self, inputs) 類別函式,描述模型從輸入到輸出的計算流程。

    def __call__(self, inputs):
        x = inputs.reshape((-1,DS_ImageFlatened))
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)      # 除了最後一層外, 其他層皆輸出 relu
        return x  # 回傳 logits

首先利用 reshape() 把輸入的圖片攤平,才能成為 Dense 子層的輸入,而後依序呼叫 6 個 Dense 子層。除了最後一層之外,每個子層皆使用 nn.relu() 作為激活函式 (activation function)。

大家可以查閱 Flax 官方文件 [38.1],看看 Flax 已經預先定義了那些子層可供使用。

精簡的模型宣告法

https://ithelp.ithome.com.tw/upload/images/20221101/20129616QrSoOLb0hx.png

精簡的方式只要宣告 __call__(self, inputs) 這個類別函式就可以了,但是要加上 @nn.compact 修飾字。和明確的宣告法一樣,我們需要__call__(self, inputs) 裏設計模型由輸入到輸出的運算流程,直接使用定義好的子層 (或先前定義好的自訂模型類別)。以這個例子來講, 我們使用了 Flax 定義好的卷積層 (nn.Conv)、最大池化層 (nn.max_pool)、及密集連結層 (nn.Dense)等等。

初始化模型

模型宣告是,指定了模型結構及運算流程,但是並沒有指定輸入資料的維度,而「初始化模型」的目的,即是在於以「虛擬輸入」來指定模型的輸入資料維度,進而決定模型最終的所需的參數數量。初始化的另一個目的,是指定模型參數的初始值,因此,在初始化時,我們傳給它 PRNG key ,以生成隨機的初始值。

初始化的目的

  1. 決定模型參數數量
  2. 指定模型參數初始值

一般來說,虛擬的輸入資料要包括批次維度,以一筆資料為一個批次即可。

# 初始化時需要 key
key = jrand.PRNGKey(3)
key, subkey1 = jrand.split(key)
 
model_cmpcnn = CompactCNN()
 
# 呼叫 init(), 並將參數保留下來
# 虛擬輸入 data : 
#  -- 要使用批次維度,但一筆資料即可。
params_cmpcnn = model_cmpcnn.init(subkey1, jnp.ones((1,32,32,3)))

init() 傳回模型的參數,要保留下來,接下來的模型訓練及模型儲存,要用到它。

Flax 的自訂模型遵循了 JAX 的精神,參數和計算分開,保持模型計算的「純粹性 pure」。

[38.1] 可以參考 flax.linen package


上一篇
JAX 好好玩 (37) : Flax (3) : 第二個範例程式
下一篇
JAX 好好玩 (39) : Flax (5) : 輔助函式及單一批次訓練函式
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言